# core/formalization/knowledge_graph.py
import os
import json
import numpy as np
import faiss
from typing import Dict, List, Any, Optional, Tuple, Set
import uuid
from collections import defaultdict
import datetime
from enum import Enum

from utils.logger import Logger
import core.agent_prompt as AgentPrompt
from utils.json_utils import extract_json
from llm.auxiliary import Auxiliary
from core.formalization.symbol_manager import SymbolManager
from llm.message import Message, MessageContent, ROLE_USER, TYPE_CONTENT

class FormalizationKnowledgeGraph:
    def __init__(self, logger: Logger, auxiliary: Auxiliary, symbol_manager: SymbolManager, config: Dict = {}):
        self.logger = logger
        self.auxiliary = auxiliary
        self.symbol_manager = symbol_manager
        self.config = config
        
        self.nodes = {}
        self.edges = defaultdict(list)
        self.term_to_node = {}
        
        self.embedding_dim = config.get("embedding_dim", 1024)
        self.index = None
        self.node_ids = []
        
        self._load_knowledge_graph()
        self._init_faiss_index()
    
    def _load_knowledge_graph(self):
        kg_file = self._get_kg_filepath()
        if not os.path.exists(kg_file):
            return
        
        try:
            with open(kg_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
                self.nodes = data.get('nodes', {})
                self.edges = defaultdict(list, data.get('edges', {}))
                self.term_to_node = data.get('term_to_node', {})
        except Exception as e:
            self.logger.log_exception(e)
    
    def _save_knowledge_graph(self):
        kg_file = self._get_kg_filepath()
        
        try:
            os.makedirs(os.path.dirname(kg_file), exist_ok=True)

            edges_dict = {k: v for k, v in self.edges.items()}
            data = {
                'nodes': self.nodes,
                'edges': edges_dict,
                'term_to_node': self.term_to_node
            }
            
            with open(kg_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            self.logger.log_exception(e)
    
    def _get_kg_filepath(self):
        cache_dir = self.config.get('cache_dir', None)
        if not cache_dir:
            raise ValueError("Unknown cache dir")
        
        kg_dir = os.path.join(cache_dir, 'knowledge_graph')
        os.makedirs(kg_dir, exist_ok=True)
        
        return os.path.join(kg_dir, "formalization_kg.json")
    
    def _get_embedding_filepath(self):
        cache_dir = self.config.get('cache_dir', None)
        if not cache_dir:
            raise ValueError("Unknown cache dir")
        
        kg_dir = os.path.join(cache_dir, 'knowledge_graph')
        os.makedirs(kg_dir, exist_ok=True)
        
        return os.path.join(kg_dir, "embeddings.npy")
    
    def _init_faiss_index(self):
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        
        embedding_file = self._get_embedding_filepath()
        if os.path.exists(embedding_file):
            try:
                embeddings = np.load(embedding_file)
                node_ids_file = embedding_file.replace('.npy', '_ids.json')
                
                with open(node_ids_file, 'r', encoding='utf-8') as f:
                    self.node_ids = json.load(f)
                
                if len(embeddings) > 0:
                    self.index.add(embeddings)
                    self.logger.info(f"Loaded {len(embeddings)} embeddings from file")
            except Exception as e:
                self.logger.log_exception(e)
                self.index = faiss.IndexFlatL2(self.embedding_dim)
                self.node_ids = []
    
    def _save_embeddings(self):
        if len(self.node_ids) == 0:
            return
        
        try:
            embedding_file = self._get_embedding_filepath()
            num_vectors = self.index.ntotal
            embeddings = np.zeros((num_vectors, self.embedding_dim), dtype=np.float32)
            
            for i in range(num_vectors):
                embeddings[i] = self.index.reconstruct(i)
            
            np.save(embedding_file, embeddings)
            
            node_ids_file = embedding_file.replace('.npy', '_ids.json')
            with open(node_ids_file, 'w', encoding='utf-8') as f:
                json.dump(self.node_ids, f)
            
            self.logger.info(f"Saved {num_vectors} embeddings to file")
        except Exception as e:
            self.logger.log_exception(e)
    
    def _get_embedding(self, text: str) -> np.ndarray:
        try:
            embedding = self.auxiliary.api_embedding(text)
            return np.array(embedding, dtype=np.float32).reshape(1, -1)
        except Exception as e:
            self.logger.log_exception(e)
            return np.zeros((1, self.embedding_dim), dtype=np.float32)
    
    def extract_and_update_knowledge(self, query: str, response: str, category: str) -> List[str]:
        extracted_info = self._extract_knowledge(query, response, category)
        updated_node_ids = []
        
        for term_info in extracted_info.get('terms', []):
            term = term_info.get('term')
            if not term:
                continue
                
            node_id = self._add_or_update_term(term_info, category)
            updated_node_ids.append(node_id)
        
        for relation in extracted_info.get('relations', []):
            source_term = relation.get('source')
            target_term = relation.get('target')
            relation_type = relation.get('type')
            
            if not source_term or not target_term or not relation_type:
                continue
                
            self._add_relation(source_term, target_term, relation_type)
        
        self._save_knowledge_graph()
        self._save_embeddings()
        
        return updated_node_ids
    
    def _add_or_update_term(self, term_info: Dict, category: str) -> str:
        term = term_info.get('term')
        
        if term in self.term_to_node:
            node_id = self.term_to_node[term]
            node = self.nodes[node_id]
            
            if not node.get('definition') and 'definition' in term_info:
                node['definition'] = term_info.get('definition')
            
            if 'representations' in term_info:
                for rep_type, rep_value in term_info.get('representations', {}).items():
                    if rep_type not in node['representations']:
                        node['representations'][rep_type] = rep_value
                    elif rep_type == 'symbolic' and 'symbolic' not in node['representations']:
                        node['representations']['symbolic'] = rep_value
            
            if 'synonyms' in term_info:
                node['synonyms'].extend(term_info.get('synonyms', []))
                node['synonyms'] = list(set(node['synonyms']))
                
                for synonym in term_info.get('synonyms', []):
                    self.term_to_node[synonym] = node_id
            
            self.nodes[node_id] = node
            self._update_node_embedding(node_id, node)
            
            return node_id
        else:
            node_id = str(uuid.uuid4())
            
            node = {
                'id': node_id,
                'term': term,
                'category': category,
                'definition': term_info.get('definition', ''),
                'representations': term_info.get('representations', {}),
                'synonyms': term_info.get('synonyms', []),
                'metadata': {
                    'created_at': datetime.datetime.now().isoformat(),
                    'confidence': term_info.get('confidence', 0.8)
                }
            }
            
            self.nodes[node_id] = node
            self.term_to_node[term] = node_id
            for synonym in term_info.get('synonyms', []):
                self.term_to_node[synonym] = node_id
            
            self._add_node_embedding(node_id, node)
            
            return node_id

    def _add_node_embedding(self, node_id: str, node: Dict):
        node_text = f"{node['term']} {node['definition']} {' '.join(node['synonyms'])}"
        
        for _, rep_value in node.get('representations', {}).items():
            if isinstance(rep_value, str):
                node_text += f" {rep_value}"
        
        embedding = self._get_embedding(node_text)
        
        self.index.add(embedding)
        self.node_ids.append(node_id)

    def _update_node_embedding(self, node_id: str, node: Dict):
        if node_id not in self.node_ids:
            self._add_node_embedding(node_id, node)
            return
        
        idx = self.node_ids.index(node_id)
        node_text = f"{node['term']} {node['definition']} {' '.join(node['synonyms'])}"

        for _, rep_value in node.get('representations', {}).items():
            if isinstance(rep_value, str):
                node_text += f" {rep_value}"
        
        embedding = self._get_embedding(node_text)

        all_embeddings = []
        for i in range(self.index.ntotal):
            if i == idx:
                all_embeddings.append(embedding[0])
            else:
                all_embeddings.append(self.index.reconstruct(i))
        
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        
        if all_embeddings:
            self.index.add(np.array(all_embeddings))
    
    def _add_relation(self, source_term: str, target_term: str, relation_type: str) -> bool:
        if source_term not in self.term_to_node or target_term not in self.term_to_node:
            return False
        
        source_id = self.term_to_node[source_term]
        target_id = self.term_to_node[target_term]
        
        if target_id in [edge[0] for edge in self.edges[source_id] if edge[1] == relation_type]:
            return False
        
        self.edges[source_id].append((target_id, relation_type))
        
        return True
    
    def _extract_knowledge(self, query: str, response: str, category: str) -> Dict:
        try:
            prompt = AgentPrompt.extract_knowledge(query, response, category)
            messages = [Message(ROLE_USER, [MessageContent(TYPE_CONTENT, prompt)])]
            response = self.auxiliary.get_api_generate_model().generate(messages)
            knowledge_info = extract_json(response)
            return knowledge_info
        except Exception as e:
            self.logger.log_exception(e)
            return {
                "terms": [],
                "relations": []
            }
        
    def enhance_query_with_knowledge(self, query: str, knowledge_nodes: List[Dict]) -> str:
        if not knowledge_nodes:
            return query
        
        enhanced_query = query
        enhanced_query += "\n\nRelevant context:\n"

        for i, node in enumerate(knowledge_nodes):
            enhanced_query += f"\n{i+1}. Term: {node.get('term', '')}\n"

            definition = node.get('definition', '')
            if definition:
                enhanced_query += f"   Definition: {definition}\n"

            representations = node.get('representations', {})
            if representations:
                enhanced_query += f"   Representations:\n"
                for rep_type, rep_value in representations.items():
                    enhanced_query += f"     - {rep_type}: {rep_value}\n"

            synonyms = node.get('synonyms', [])
            if synonyms:
                enhanced_query += f"   Synonyms: {', '.join(synonyms)}\n"

        enhanced_query += "\n\nPlease use the above relevant context to help with answering the original query."
        
        self.logger.info(f"Enhanced query [{query}] result:\n{enhanced_query}")
        return enhanced_query

    def _calculate_relevance(self, query: str, node: Dict) -> float:
        query_terms = set(query.lower().split())
        node_term = node.get('term', '').lower()
        node_synonyms = [s.lower() for s in node.get('synonyms', [])]
        node_definition = node.get('definition', '').lower().split()
        
        all_node_terms = set([node_term] + node_synonyms + node_definition)
        
        # Jaccard similarity
        intersection = len(query_terms.intersection(all_node_terms))
        union = len(query_terms.union(all_node_terms))
        
        term_score = intersection / union if union > 0 else 0
        query_embedding = self._get_embedding(query)
        
        node_text = f"{node.get('term', '')} {node.get('definition', '')} {' '.join(node.get('synonyms', []))}"
        
        for _, rep_value in node.get('representations', {}).items():
            if isinstance(rep_value, str):
                node_text += f" {rep_value}"
        
        node_embedding = self._get_embedding(node_text)
        
        # Cosine similarity
        dot_product = np.sum(query_embedding * node_embedding)
        query_norm = np.sqrt(np.sum(query_embedding ** 2))
        node_norm = np.sqrt(np.sum(node_embedding ** 2))
        
        embedding_score = dot_product / (query_norm * node_norm) if query_norm * node_norm > 0 else 0
        
        relevance = (term_score + embedding_score) / 2
        confidence = node.get('metadata', {}).get('confidence', 0.5)
        return relevance * confidence
    
    def search_knowledge(self, query: str, top_k: int = 5) -> List[Dict]:
        key_terms = self._extract_key_terms(query)

        candidate_nodes = set()
        for term in key_terms:
            if term in self.term_to_node:
                candidate_nodes.add(self.term_to_node[term])
        
        vector_results = self._vector_search(query, top_k)
        
        combined_results = list(candidate_nodes)
        for node_id in vector_results:
            if node_id not in combined_results:
                combined_results.append(node_id)
        
        scored_nodes = []
        for node_id in combined_results:
            node = self.nodes.get(node_id)
            if node:
                score = self._calculate_relevance(query, node)
                scored_nodes.append((node_id, score))
        
        scored_nodes.sort(key=lambda x: x[1], reverse=True)
        top_nodes = [self.nodes[node_id] for node_id, _ in scored_nodes[:top_k]]
        
        return top_nodes
    
    def _extract_key_terms(self, query: str) -> List[str]:
        all_terms = list(self.term_to_node.keys())
        query_terms = []

        prompt = AgentPrompt.extract_key_term(query, all_terms)
        messages = [Message(ROLE_USER, [MessageContent(TYPE_CONTENT, prompt)])]
        
        try:
            response = self.auxiliary.get_api_generate_model().generate(messages)
            data = extract_json(response)
            matched_terms = data.get("matched_terms", [])
            new_terms = data.get("new_terms", [])
            
            all_extracted_terms = matched_terms + new_terms
            self.logger.info(f"Extracted terms: {all_extracted_terms}")
            return all_extracted_terms

        except Exception as e:
            self.logger.log_exception(e)
            return []
    
    def _vector_search(self, query: str, top_k: int) -> List[str]:
        if self.index.ntotal == 0:
            return []
        
        query_embedding = self._get_embedding(query)
        k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(query_embedding, k)
        result_ids = [self.node_ids[idx] for idx in indices[0]]
        return result_ids
    
    def get_term_by_id(self, node_id: str) -> Optional[Dict]:
        return self.nodes.get(node_id)
    
    def get_term_by_name(self, term: str) -> Optional[Dict]:
        if term in self.term_to_node:
            node_id = self.term_to_node[term]
            return self.nodes.get(node_id)
        return None
    
    def get_related_terms(self, term: str, max_depth: int = 2) -> List[Dict]:
        if term not in self.term_to_node:
            return []
        
        node_id = self.term_to_node[term]
        
        # BFS to search related nodes
        visited = set([node_id])
        queue = [(node_id, 0)]  # (node_id, depth)
        related_nodes = []
        
        while queue:
            current_id, depth = queue.pop(0)
            
            if depth >= max_depth:
                continue
            
            for neighbor_id, relation_type in self.edges.get(current_id, []):
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    queue.append((neighbor_id, depth + 1))
                    
                    node = self.nodes.get(neighbor_id)
                    if node:
                        related_nodes.append(node)
        
        return related_nodes
